热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

评测|CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

评测|CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

作者:Max Woolf

机器之心编译

参与:Jane W、吴攀

Keras 是由 François Chollet 维护的深度学习高级开源框架,它的底层基于构建生产级质量的深度学习模型所需的大量设置和矩阵代数。Keras API 的底层基于像 Theano 或谷歌的 TensorFlow 的较低级的深度学习框架。Keras 可以通过设置 flag 自由切换后端(backend)引擎 Theano/TensorFlow;而不需要更改前端代码。

虽然谷歌的 TensorFlow 已广受关注,但微软也一直在默默地发布自己的机器学习开源框架。例如 LightGBM 框架,可以作为著名的 xgboost 库的替代品。例如几周前发布的 CNTK v2.0(Microsoft Cognitive Toolkit),它与 TensorFlow 相比,显示出在准确性和速度方面的强劲性能。参阅机器之心报道《开源 | 微软发行 Cognitive Toolkit 2.0 完整版:从性能更新到应用案例》。

CNTK v2.0 还有一个关键特性:兼容 Keras。就在上周,对 CNTK 后端的支持被合并到官方的 Keras 资源库(repository)中。

Hacker News 论坛对于 CNTK v2.0 也有评论(https://news.ycombinator.com/item?id=14470967),微软员工声称,将 Keras 的后端由 TensorFlow 改为 CNTK 可以显著提升性能。那么让我们来检验这句话的真伪吧。

在云端进行深度学习

在云端设置基于 GPU 的深度学习实例令人惊讶地被忽视了。大多数人建议使用亚马逊 AWS 服务,它包含所有可用的 GPU 驱动,只需参照固定流程(https://blog.keras.io/running-jupyter-notebooks-on-gpu-on-aws-a-starter-guide.html)设置远程操作。然而,对于 NVIDIA Tesla K80 GPU,亚马逊 EC2 收费 $0.90/小时(不按时长比例收费);对于相同的 GPU,谷歌 Compute Engine(GCE)收费 $0.75/小时(按分钟比例收费),这对于需要训练许多小时的深度学习模型是非常显著的弱点。

要使用 GCE,你必须从一个空白的 Linux 实例中设置深度学习的驱动和框架。我使用 Keras 进行了第一次尝试(http://minimaxir.com/2017/04/char-embeddings/),但这并不有趣。不过,我最近受到 Durgesh Mankekar 文章(https://medium.com/google-cloud/containerized-jupyter-notebooks-on-gpu-on-google-cloud-8e86ef7f31e9)的启发,该文章采用了 Docker 容器这种更现代的方法来管理依赖关系,该文章还介绍了名为 Dockerfile 的安装脚本和容器与 Keras 必需的深度学习驱动/框架。Docker 容器可以使用 nvidia-docker 进行加载,这可以让 Docker 容器访问主机上的 GPU。在容器中运行深度学习脚本只需运行 Docker 命令行。当脚本运行完后,会自动退出容器。这种方法恰巧保证了每次执行是独立的;这为基准评估/重复执行提供了理想的环境。

我稍微调整了 Docker 容器(GitHub 网址 https://github.com/minimaxir/keras-cntk-docker),容器安装了 CNTK、与 CNTK 兼容的 Keras 版本,并设置 CNTK 为 Keras 的默认后端。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

基准方法

Keras 的官方案例(https://github.com/fchollet/keras/tree/master/examples)非常全面,涉及多种现实中的深度学习问题,并能完美地模拟 Keras 在不同模型的性能。我选取了强调不同神经网络架构的几个例子(https://github.com/minimaxir/keras-cntk-benchmark/tree/master/test_files),并添加了一个自定义 logger,它能够输出含有模型性能和训练时间进程的 CSV 文件。

如前所述,只需要设置一个 flag 就能方便地切换后端引擎。即使 Docker 容器中 Keras 的默认后端是 CNTK,一个简单的 -e KERAS_BACKEND ='tensorflow' 命令语句就可以切换到 TensorFlow。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

我写了一个 Python 基准脚本(https://github.com/minimaxir/keras-cntk-benchmark/blob/master/keras_cntk_benchmark.py)(在主机上运行)来管理并运行 Docker 容器中的所有例子,它同时支持 CNTK 和 TensorFlow 后端,并用 logger 收集生成的日志。

下面是不同数据集的结果。

IMDb 评论数据集

IMDb 评论数据集(http://ai.stanford.edu/~amaas/data/sentiment/)是用于情感分析的著名的自然语言处理(NLP)基准数据集。数据集中的 25000 条评论被标记为「积极」或「消极」。在深度学习成为主流之前,优秀的机器学习模型在测试集上达到大约 88% 的分类准确率。

第一个模型方法(imdb_bidirectional_lstm.py)使用了双向 LSTM(Bidirectional LSTM),它通过词序列对模型进行加权,同时采用向前(forward)传播和向后(backward)传播的方法。

首先,我们来看一下在训练模型时的不同时间点测试集的分类准确率:

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

通常,准确率随着训练的进行而增加;双向 LSTM 需要很长时间来训练才能得到改进的结果,但至少这两个框架都是同样有效的。

为了评估算法的速度,我们可以计算训练一个 epoch 所需的平均时间。每个 epoch 的时间大致相同;测量结果真实平均值用 95%的置信区间表示,这是通过非参数统计的 bootstrapping 方法得到的。双向 LSTM 的计算速度:

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

哇,CNTK 比 TensorFlow 快很多!虽然没有比 LSTM 的基准测试(https://arxiv.org/abs/1608.07249)快 5-10 倍,但是仅通过设置后端 flag 就几乎将运行时间减半就已经够令人震惊了。

接下来,我们用同样的数据集测试 fasttext 方法(imdb_fasttext.py)。fasttext 是一种较新的算法,可以计算词向量嵌入(word vector Embedding)的平均值(不论顺序),但是即使在使用 CPU 时也能得到令人难以置信的速度和效果,如同 Facebook 官方对 fasttext 的实现(https://github.com/facebookresearch/fastText)一样。(对于此基准,我倾向于使用二元语法模型/bigram)

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

由于模型简单,这两种框架的准确率几乎相同,但在使用词嵌入的情况下,TensorFlow 速度更快。(不管怎样,fasttext 明显比双向 LSTM 方法快得多!)此外,fasttext 打破了 88%的基准,这可能值得考虑在其它机器学习项目中推广。

MNIST 数据集

MNIST 数据集(http://yann.lecun.com/exdb/mnist/)是另一个著名的手写数字数据集,经常用于测试计算机视觉模型(60000 个训练图像,10000 个测试图像)。一般来说,良好的模型在测试集上可达到 99%以上的分类准确率。

多层感知器(multilayer perceptron/MLP)方法(mnist_mlp.py)仅使用一个大型全连接网络,就达到深度学习魔术(Deep Learning Magic™)的效果。有时候这样就够了。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

这两个框架都能极速地训练模型,每个 epoch 只需几秒钟;在准确性方面没有明确的赢家(尽管没有打破 99%),但是 CNTK 速度更快。

另一种方法(mnist_cnn.py)是卷积神经网络(CNN),它利用相邻像素之间的固有关系建模,是一种逻辑上更贴近图像数据的架构。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

在这种情况下,TensorFlow 在准确率和速度方面都表现更好(同时也打破 99%的准确率)。

CIFAR-10

现在来研究更复杂的实际模型,CIFAR-10 数据集(https://www.cs.toronto.edu/~kriz/cifar.html)是用于 10 个不同对象的图像分类的数据集。基准脚本的架构(cifar10_cnn.py)是很多层的 Deep CNN + MLP,其架构类似于著名的 VGG-16(https://gist.github.com/baraldilorenzo/07d7802847aaad0a35d3)模型,但更简单,由于大多数人没有用来训练的超级计算机集群。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

在这种情况下,两个后端的在准确率和速度上的性能均相等。也许 CNTK 更利于 MLP,而 TensorFlow 更利于 CNN,两者的优势互相抵消。

尼采文本生成

基于 char-rnn(https://github.com/karpathy/char-rnn)的文本生成(lstm_text_generation.py)很受欢迎。具体来说,它使用 LSTM 来「学习」文本并对新文本进行抽样。在使用随机的尼采文集(https://s3.amazonaws.com/text-datasets/nietzsche.txt)作为源数据集的 Keras 例子中,该模型尝试使用前 40 个字符预测下一个字符,并尽量减少训练的损失函数值。理想情况的是损失函数值低于 1.00,并且生成的文本语法一致。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

两者的损失函数值随时间都有相似的变化(不幸的是,1.40 的损失函数值下,仍有乱码文本生成),由于 LSTM 架构,CTNK 的速度更快。

对于下一个基准测试,我将不使用官方的 Keras 示例脚本,而是使用我自己的文本生成器架构(text_generator_keras.py),详见之前关于 Keras 的文章(http://minimaxir.com/2017/04/char-embeddings)。

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

我的网络避免了过早收敛,对于 TensorFlow,只需损失很小的训练速度;不幸的是,CNTK 的速度比简单模型慢了许多,但在高级模型中仍然比 TensorFlow 快得多。

以下是用 TensorFlow 训练的我的架构模型生成的文本输出:

hinks the rich man must be wholly perverity and connection of the english sin of the philosophers of the basis of the same profound of his placed and evil and exception of fear to plants to me such as the case of the will seems to the will to be every such a remark as a primates of a strong of [...]

这是用 CNTK 训练的模型输出:

(_x2js1hevjg4z_?z_aæ?q_gpmj:sn![?(f3_ch=lhw4y n6)gkh kujau momu,?!ljë7g)k,!?[45 0as9[d.68éhhptvsx jd_næi,ä_z!cwkr"_f6ë-mu_(epp [...]

等等,什么?显然,我的模型架构导致 CNTK 在预测时遇到错误,而「CNTK+简单的 LSTM」架构并没有发生这种错误。通过质量评估,我发现批归一化(batch normalization)是错误的原因,并及时提出了这个问题(https://github.com/Microsoft/CNTK/issues/1994)。

结论

综上,评价 Keras 框架是否比 TensorFlow 更好,这个判断并没有设想中的那么界限分明。两个框架的准确性大致相同。CNTK 在 LSTM/MLP 上更快,TensorFlow 在 CNN/词嵌入(Embedding)上更快,但是当网络同时实现两者时,它们会打个平手。

撇开随机错误,有可能 CNTK 在 Keras 上的运行还没有完全优化(实际上,1bit-SGD 的设置不起作用(https://github.com/Microsoft/CNTK/issues/1975)),所以未来还是有改进的空间的。尽管如此,简单地设置 flag 的效果是非常显著的,在将它们部署到生产之前,值得在 CNTK 和 TensorFlow 后端上测试 Keras 模型,以比较两者哪个更好。  评测 | CNTK在Keras上表现如何?能实现比TensorFlow更好的深度学习吗?

原文链接:http://minimaxir.com/2017/06/keras-cntk/

版权声明

本文仅代表作者观点,不代表百度立场。

阅读量: 0

0

0


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 我们


推荐阅读
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 本文介绍了Hyperledger Fabric外部链码构建与运行的相关知识,包括在Hyperledger Fabric 2.0版本之前链码构建和运行的困难性,外部构建模式的实现原理以及外部构建和运行API的使用方法。通过本文的介绍,读者可以了解到如何利用外部构建和运行的方式来实现链码的构建和运行,并且不再受限于特定的语言和部署环境。 ... [详细]
  • Nginx使用AWStats日志分析的步骤及注意事项
    本文介绍了在Centos7操作系统上使用Nginx和AWStats进行日志分析的步骤和注意事项。通过AWStats可以统计网站的访问量、IP地址、操作系统、浏览器等信息,并提供精确到每月、每日、每小时的数据。在部署AWStats之前需要确认服务器上已经安装了Perl环境,并进行DNS解析。 ... [详细]
  • 本文介绍了lua语言中闭包的特性及其在模式匹配、日期处理、编译和模块化等方面的应用。lua中的闭包是严格遵循词法定界的第一类值,函数可以作为变量自由传递,也可以作为参数传递给其他函数。这些特性使得lua语言具有极大的灵活性,为程序开发带来了便利。 ... [详细]
  • GetWindowLong函数
    今天在看一个代码里头写了GetWindowLong(hwnd,0),我当时就有点费解,靠,上网搜索函数原型说明,死活找不到第 ... [详细]
  • VScode格式化文档换行或不换行的设置方法
    本文介绍了在VScode中设置格式化文档换行或不换行的方法,包括使用插件和修改settings.json文件的内容。详细步骤为:找到settings.json文件,将其中的代码替换为指定的代码。 ... [详细]
  • 本文介绍了在开发Android新闻App时,搭建本地服务器的步骤。通过使用XAMPP软件,可以一键式搭建起开发环境,包括Apache、MySQL、PHP、PERL。在本地服务器上新建数据库和表,并设置相应的属性。最后,给出了创建new表的SQL语句。这个教程适合初学者参考。 ... [详细]
  • baresip android编译、运行教程1语音通话
    本文介绍了如何在安卓平台上编译和运行baresip android,包括下载相关的sdk和ndk,修改ndk路径和输出目录,以及创建一个c++的安卓工程并将目录考到cpp下。详细步骤可参考给出的链接和文档。 ... [详细]
  • Android Studio Bumblebee | 2021.1.1(大黄蜂版本使用介绍)
    本文介绍了Android Studio Bumblebee | 2021.1.1(大黄蜂版本)的使用方法和相关知识,包括Gradle的介绍、设备管理器的配置、无线调试、新版本问题等内容。同时还提供了更新版本的下载地址和启动页面截图。 ... [详细]
  • 使用在线工具jsonschema2pojo根据json生成java对象
    本文介绍了使用在线工具jsonschema2pojo根据json生成java对象的方法。通过该工具,用户只需将json字符串复制到输入框中,即可自动将其转换成java对象。该工具还能解析列表式的json数据,并将嵌套在内层的对象也解析出来。本文以请求github的api为例,展示了使用该工具的步骤和效果。 ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • 本文介绍了C函数ispunct()的用法及示例代码。ispunct()函数用于检查传递的字符是否是标点符号,如果是标点符号则返回非零值,否则返回零。示例代码演示了如何使用ispunct()函数来判断字符是否为标点符号。 ... [详细]
  • 《数据结构》学习笔记3——串匹配算法性能评估
    本文主要讨论串匹配算法的性能评估,包括模式匹配、字符种类数量、算法复杂度等内容。通过借助C++中的头文件和库,可以实现对串的匹配操作。其中蛮力算法的复杂度为O(m*n),通过随机取出长度为m的子串作为模式P,在文本T中进行匹配,统计平均复杂度。对于成功和失败的匹配分别进行测试,分析其平均复杂度。详情请参考相关学习资源。 ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • 本文介绍了在Windows环境下如何配置php+apache环境,包括下载php7和apache2.4、安装vc2015运行时环境、启动php7和apache2.4等步骤。希望对需要搭建php7环境的读者有一定的参考价值。摘要长度为169字。 ... [详细]
author-avatar
mobiledu2502856973
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有